import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.nn import MSELoss

from metrics import test_score_model
from configs import DEVICE
from configs import args


def train_epoch(model: nn.Module, train_dataloader: DataLoader, optimizer, scheduler):
    """Run one training epoch and return the mean per-step loss.

    Each batch is unpacked as ``(input_ids, visual, acoustic, label_ids)`` and
    moved to ``DEVICE``. Dim 1 of the visual and acoustic tensors is squeezed.
    Both are then min-max scaled using the min and max over the whole batch
    tensor (not per sample or per feature).

    The model is expected to return a 6-tuple
    ``(logits_o, logits_co, loss_c, unimodal_mse_loss, align_loss, KL)``. These
    are combined into the training objective with the weights
    ``args.p_alpha``, ``args.p_lambda1`` and ``args.p_lambda2``.

    Gradients are accumulated over ``args.gradient_accumulation_step`` batches
    before each ``optimizer.step()`` / ``scheduler.step()``.

    Args:
        model: Model to train; switched to train mode.
        train_dataloader: Yields 4-tuples of tensors as described above.
        optimizer: Optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per optimizer step.

    Returns:
        Sum of per-step losses divided by the number of steps. When gradient
        accumulation is enabled, each step's loss is the value already divided
        by the accumulation factor.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    model.train()
    loss_fct = MSELoss()
    accum_steps = args.gradient_accumulation_step
    tr_loss = 0.0
    nb_tr_steps = 0

    # Start from clean gradients. If the previous epoch ended mid-way through an
    # accumulation window, its leftover gradients would otherwise be folded
    # into this epoch's first optimizer step.
    optimizer.zero_grad()

    for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
        input_ids, visual, acoustic, label_ids = tuple(t.to(DEVICE) for t in batch)
        visual = torch.squeeze(visual, 1)
        acoustic = torch.squeeze(acoustic, 1)

        # Batch-wide min-max scaling. The range is clamped so that a constant
        # batch yields zeros instead of NaN (0 / 0); otherwise unchanged.
        visual_norm = (visual - visual.min()) / (visual.max() - visual.min()).clamp(min=1e-12)
        acoustic_norm = (acoustic - acoustic.min()) / (acoustic.max() - acoustic.min()).clamp(min=1e-12)

        logits_o, logits_co, loss_c, unimodal_mse_loss, align_loss, KL = model(
            input_ids,
            visual_norm,
            acoustic_norm,
            label_ids
        )
        loss_o = loss_fct(logits_o.view(-1), label_ids.view(-1))
        loss_co = loss_fct(logits_co.view(-1), label_ids.view(-1))
        loss = (
            loss_o
            + args.p_alpha * unimodal_mse_loss
            + KL
            + args.p_lambda1 * (loss_c + align_loss)
            + args.p_lambda2 * loss_co
        )

        # Scale so the accumulated gradient matches the average over the window.
        if accum_steps > 1:
            loss = loss / accum_steps

        loss.backward()

        tr_loss += loss.item()
        nb_tr_steps += 1

        if (step + 1) % accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    if nb_tr_steps == 0:
        raise ValueError("train_dataloader yielded no batches")
    return tr_loss / nb_tr_steps


def eval_epoch(model: nn.Module, dev_dataloader: DataLoader):
    """Compute the mean validation MSE of the model's primary output.

    Batches get the same preprocessing as in training:
    - squeeze dim 1 of the visual and acoustic tensors;
    - min-max scale each with the min and max over the whole batch tensor.

    Only ``logits_o`` (the first model output) is scored against
    ``label_ids``; the remaining five outputs are ignored.

    Args:
        model: Model to evaluate; switched to eval mode.
        dev_dataloader: Yields ``(input_ids, visual, acoustic, label_ids)``.

    Returns:
        Mean over batches of the per-batch MSE between ``logits_o`` and
        ``label_ids``.

    Raises:
        ValueError: If ``dev_dataloader`` yields no batches.
    """
    model.eval()
    loss_fct = MSELoss()
    dev_loss = 0.0
    nb_dev_steps = 0
    with torch.no_grad():
        for batch in tqdm(dev_dataloader, desc="Iteration"):
            input_ids, visual, acoustic, label_ids = tuple(t.to(DEVICE) for t in batch)
            visual = torch.squeeze(visual, 1)
            acoustic = torch.squeeze(acoustic, 1)

            # Batch-wide min-max scaling; the clamp turns a constant batch into
            # zeros instead of NaN (0 / 0).
            visual_norm = (visual - visual.min()) / (visual.max() - visual.min()).clamp(min=1e-12)
            acoustic_norm = (acoustic - acoustic.min()) / (acoustic.max() - acoustic.min()).clamp(min=1e-12)

            logits_o, _logits_co, _loss_c, _unimodal_mse_loss, _align_loss, _kl = model(
                input_ids,
                visual_norm,
                acoustic_norm,
                label_ids
            )
            # The validation loss is never back-propagated, so it is not divided
            # by the gradient-accumulation factor.
            dev_loss += loss_fct(logits_o.view(-1), label_ids.view(-1)).item()
            nb_dev_steps += 1

    if nb_dev_steps == 0:
        raise ValueError("dev_dataloader yielded no batches")
    return dev_loss / nb_dev_steps


def test_epoch(model: nn.Module, test_dataloader: DataLoader):
    """Collect the model's primary predictions and the labels over a loader.

    Preprocessing matches training:
    - squeeze dim 1 of the visual and acoustic tensors;
    - min-max scale each with the min and max over the whole batch tensor.

    Args:
        model: Model to run; switched to eval mode.
        test_dataloader: Yields ``(input_ids, visual, acoustic, label_ids)``.

    Returns:
        ``(preds, labels)`` as NumPy arrays. Each batch's ``logits_o`` and
        ``label_ids`` are ``np.squeeze``-d and concatenated along the first
        axis.
    """
    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            input_ids, visual, acoustic, label_ids = tuple(t.to(DEVICE) for t in batch)
            visual = torch.squeeze(visual, 1)
            acoustic = torch.squeeze(acoustic, 1)

            # Batch-wide min-max scaling; the clamp turns a constant batch into
            # zeros instead of NaN (0 / 0).
            visual_norm = (visual - visual.min()) / (visual.max() - visual.min()).clamp(min=1e-12)
            acoustic_norm = (acoustic - acoustic.min()) / (acoustic.max() - acoustic.min()).clamp(min=1e-12)

            logits_o, _logits_co, _loss_c, _unimodal_mse_loss, _align_loss, _kl = model(
                input_ids,
                visual_norm,
                acoustic_norm,
                label_ids
            )

            # np.squeeze reduces a size-1 batch to a 0-d array, whose tolist()
            # is a bare float that list.extend cannot consume. atleast_1d keeps
            # it a one-element list and is a no-op for larger batches.
            batch_preds = np.atleast_1d(np.squeeze(logits_o.detach().cpu().numpy()))
            batch_labels = np.atleast_1d(np.squeeze(label_ids.detach().cpu().numpy()))

            preds.extend(batch_preds.tolist())
            labels.extend(batch_labels.tolist())

    return np.array(preds), np.array(labels)


def train(
        model,
        train_dataloader,
        validation_dataloader,
        test_data_loader,
        optimizer,
        scheduler,
):
    valid_losses = []
    test_accuracies = []
    mae_list = []
    corr_list = []
    f1_list = []
    best_valid_loss = float('inf')
    for epoch_i in range(int(args.n_epochs)):
        train_loss = train_epoch(model, train_dataloader, optimizer, scheduler)
        valid_loss = eval_epoch(model, validation_dataloader)

        print(
            "TRAIN: epoch:{}, train_loss:{}, valid_loss:{}".format(
                epoch_i + 1, train_loss, valid_loss
            )
        )

        preds, labels = test_epoch(model, test_data_loader)
        acc7, acc2, f1_score, acc2_non, f1_score_non, mae, corr = test_score_model(preds, labels)

        if best_valid_loss > valid_loss:
            best_valid_loss = valid_loss
            best_acc7 = acc7
            best_acc2 = acc2
            best_acc2_non = acc2_non
            best_f1_score = f1_score
            best_f1_score_non = f1_score_non
            best_mae = mae
            best_corr = corr

            print("\n✅ TEST [BEST VALID LOSS]: "
                "Acc7: {:.2f}, Acc2: {:.2f}/{:.2f}, F1: {:.2f}/{:.2f}, MAE: {:.3f}, Corr: {:.3f}\n".format(
                    acc7 * 100, acc2 * 100, acc2_non*100, f1_score*100, f1_score_non*100, mae, corr))
        else:
            print("\nTEST: "
                "Acc7: {:.2f}, Acc2: {:.2f}/{:.2f}, F1: {:.2f}/{:.2f}, MAE: {:.3f}, Corr: {:.3f}\n".format(
                    acc7 * 100, acc2 * 100, acc2_non*100, f1_score*100, f1_score_non*100, mae, corr))

    print("\nResults: "
        "Acc7: {:.2f}, Acc2: {:.2f}/{:.2f}, F1: {:.2f}/{:.2f}, MAE: {:.3f}, Corr: {:.3f}\n".format(
            best_acc7 * 100, best_acc2 * 100, best_acc2_non*100, best_f1_score*100, best_f1_score_non*100, best_mae, best_corr))



